const math = require('mathjs')
const rnorm = require('randgen').rnorm
const path = require('path')
const fs = require('fs-extra');


/**
 * Apply one gradient-descent step to a weight matrix and return the new matrix.
 * The gradient uses c .* (1 - c), the derivative of the sigmoid written in
 * terms of its output c, so the update presumes a sigmoid activation.
 * @param {number} a learning rate
 * @param {number[][]} b error at this layer's output
 * @param {number[][]} c this layer's output (activation)
 * @param {number[][]} d output of the preceding layer (this layer's input)
 * @param {number[][]} e current weight matrix
 * @returns {number[][]} e + a .* ((b .* c .* (1 - c)) · dᵀ)
 */
function weight_calc(a, b, c, d, e) {
    const slope = math.dotMultiply(math.dotMultiply(b, c), math.subtract(1, c));
    const step = math.dotMultiply(a, math.multiply(slope, math.transpose(d)));
    return math.add(e, step);
}

/**
 * Logistic sigmoid activation function: 1 / (1 + e^(-x)).
 * @param {number} x input value
 * @returns {number} value between 0 and 1
 */
function sigmoid(x) {
    const decay = Math.exp(-x);
    return 1 / (decay + 1);
}

/**
 * Apply a function to every element of a two-dimensional array.
 * @param {Array<Array<*>>} matrix array of rows
 * @param {function} [fn=sigmoid] element-wise function; it receives only the element value
 * @returns {Array<Array<*>>} a new array of rows with fn applied to each element
 */
function map(matrix, fn = sigmoid) {
    const applyToRow = row => row.map(value => fn(value));
    return matrix.map(applyToRow);
}

/**
 * Compute one layer's output signals: activation_fn(weight · input), element-wise.
 * @param {number[][]} weight weight matrix (its column count must match input's row count)
 * @param {number[][]} input column matrix of incoming signals
 * @param {function} [activation_fn] element-wise activation; map() falls back to sigmoid when undefined
 * @returns {number[][]} column matrix of output signals
 */
function output_calc(weight, input, activation_fn) {
    const weightedSum = math.multiply(weight, input);
    return map(weightedSum, activation_fn);
}

/**
 * Turn a flat array into a column matrix: [a, b, c] -> [[a], [b], [c]].
 * @param {Array} arr flat array
 * @returns {Array<Array>} array of single-element rows
 */
function convert_matrix(arr) {
    return arr.map(value => [value]);
}

/**
 * Build a len1 × len2 matrix of random initial weights drawn from a normal
 * distribution with mean 0 and standard deviation 1/√len2.
 *
 * The network multiplies these matrices by a column of len2 input signals
 * (weight · input), so each row holds the weights of one receiving node and
 * len2 is the number of incoming links per node. The spread is therefore
 * scaled by the fan-in (len2), not by the row count (len1).
 * @param {number} len1 number of rows (nodes in the receiving layer)
 * @param {number} len2 number of columns (nodes in the sending layer)
 * @returns {number[][]} plain nested array of weights
 */
function randomWeightMatrix(len1, len2) {
    const std = Math.pow(len2, -0.5);
    return math.zeros(len1, len2).map(() => rnorm(0, std)).toArray();
}


// Fully connected neural network with one hidden layer (input -> hidden -> output).
class NeuralNetwork {
    /**
     * Initialise the network (the counterpart of Python's __init__).
     * Options that are left out fall back to the defaults below; extra keys are
     * accepted and ignored. If the argument is not a plain object (including
     * null or an array), all defaults are used.
     * @param {object} [info_obj] configuration options
     * @param {number[]} [info_obj.node_num=[784, 100, 10]] node counts [input layer, hidden layer, output layer]
     * @param {number} [info_obj.learning_grate=0.1] learning rate (key spelling kept for compatibility)
     * @param {function} [info_obj.activation_fn=sigmoid] element-wise activation function
     * @param {function} [info_obj.inverse_activation_fn] inverse activation for reverse queries;
     *   stored in this.iaf but not used by this class
     */
    constructor(info_obj) {
        const defaults = {
            node_num: [784, 100, 10],
            learning_grate: 0.1,
            activation_fn: sigmoid,
            inverse_activation_fn: () => {},
        };
        // typeof null === 'object', so null has to be excluded explicitly.
        const isOptionsObject = info_obj !== null
            && typeof info_obj === 'object'
            && !Array.isArray(info_obj);
        // Object spread defines plain data properties, so a "__proto__" key in
        // the options cannot replace the prototype of the merged object.
        const init_info = isOptionsObject ? { ...defaults, ...info_obj } : defaults;

        // Node counts per layer
        [this.ni, this.nh, this.no] = init_info.node_num;
        // Learning rate
        this.lr = init_info.learning_grate;
        // Activation function and its inverse
        this.af = init_info.activation_fn;
        this.iaf = init_info.inverse_activation_fn;

        // Random initial weights. Rows are the receiving layer and columns the
        // sending layer, matching the weight · input products in query().
        this.wih = randomWeightMatrix(this.nh, this.ni);
        this.who = randomWeightMatrix(this.no, this.nh);
    }

    /**
     * Run one back-propagation step on a single example and update both weight matrices.
     * NOTE: weight_calc applies the sigmoid derivative o * (1 - o), so the update
     * presumes that activation_fn is the sigmoid.
     * @param {number[]} input_list input values (length ni)
     * @param {number[]} target_list desired output values (length no)
     */
    train(input_list, target_list) {
        const input = convert_matrix(input_list);
        const target = convert_matrix(target_list);

        // Forward pass; query() also leaves the hidden-layer output in this.h_o.
        const final_output = this.query(input_list);
        const hidden_output = this.h_o;

        // Output-layer error
        const output_error = math.subtract(target, final_output);
        // Hidden-layer error, computed with the output weights *before* they are updated
        const hidden_error = math.multiply(math.transpose(this.who), output_error);

        // Gradient-descent update of both weight matrices
        this.who = weight_calc(this.lr, output_error, final_output, hidden_output, this.who);
        this.wih = weight_calc(this.lr, hidden_error, hidden_output, input, this.wih);
    }

    /**
     * Forward pass through the network.
     * Side effect: stores the hidden-layer output in this.h_o (train() relies on it).
     * @param {number[]} input_list input values (length ni)
     * @returns {number[][]} output layer as a column matrix (no rows, 1 column)
     */
    query(input_list) {
        this.h_o = output_calc(this.wih, convert_matrix(input_list), this.af);
        return output_calc(this.who, this.h_o, this.af);
    }
}


// Example configuration for a tiny 3-2-2 network, used by the commented-out demo below.
const nn_info = {
    node_num: [3, 2, 2],
    inverse_activation_fn: () => {},
};

// Instantiate the network class:
// const nn = new NeuralNetwork(nn_info)
//
// nn.train([1, 2, 3], [0.01, 0.99])
// console.log(nn.query([1, 2, 3]));

/**
 * Drop the leading label field of a CSV record and rescale the remaining values
 * with v / 255 * 0.99 + 0.01 (0 maps to 0.01 and 255 maps to 1.0, assuming
 * 0–255 inputs). String fields are coerced to numbers by the arithmetic.
 * @param {Array<string|number>} arr record fields, label first
 * @returns {number[]} scaled values, without the label
 */
function normalize_arr(arr) {
    const pixels = arr.slice(1);
    return pixels.map(pixel => (pixel / 255) * 0.99 + 0.01);
}

/**
 * Read a file and pass its contents to callback.
 * The callback runs only when the read succeeds. A read failure is logged and
 * the callback is skipped; previously it was still called with undefined.
 * @param {string} file_path path of the file to read
 * @param {function(Buffer): *} callback receives the file contents
 * @returns {Promise<*>} resolves to the callback's return value, or to undefined
 *   after a read failure; rejects if the callback itself throws
 */
function readFile(file_path, callback) {
    return fs.readFile(file_path).then(
        res => callback(res),
        err => console.error('数据读取失败！', err)
    );
}

/**
 * Build the target vector for one labelled example: 0.01 for every class
 * except the labelled one, which gets 0.99. These values are presumably
 * chosen because a sigmoid output never reaches exactly 0 or 1.
 * @param {number} target_len number of output classes
 * @param {number|string} target class label (may be the raw CSV field, e.g. '7')
 * @returns {number[]} array of length target_len
 * @throws {RangeError} if target is blank or not an integer in [0, target_len)
 */
function getTarget(target_len, target) {
    const index = Number(target);
    // Number('') is 0, so blank labels must be rejected explicitly.
    if (String(target).trim() === '' || !Number.isInteger(index) || index < 0 || index >= target_len) {
        throw new RangeError(`invalid class label: ${target}`);
    }
    const temp = Array.from({ length: target_len }, () => 0.01);
    temp[index] = 0.99;
    return temp;
}

/**
 * Create a NeuralNetwork and train it on a CSV file of labelled records
 * (label first, then the raw input values, one record per line).
 * Options in network_info override the defaults below, and the merged object
 * is also passed to the NeuralNetwork constructor.
 * NOTE: callback_each_epoch, rotate and rotate_info are accepted but not used yet.
 * @param {object} [network_info] training / network options
 * @param {string} [network_info.train_file] CSV file to train on
 * @param {number} [network_info.epochs=1] number of passes over the file
 * @param {function(NeuralNetwork)} [network_info.callback] called with the trained network
 * @returns {Promise<NeuralNetwork|undefined>} resolves to the trained network;
 *   on any error the error is logged and the promise resolves to undefined
 */
function train(network_info = {}) {
    const _init = {
        node_num: [784, 200, 10],
        learning_grate: 0.1,
        train_file: path.join(__dirname, './mnist_dataset/mini/mnist_train_100.csv'),
        epochs: 1,
        rotate: true,
        rotate_info: [10, -10],
        callback_each_epoch: () => {},
        callback: () => {},
        ...network_info,
    };

    // Instantiate the network
    const nn = new NeuralNetwork(_init);

    return fs.readFile(_init.train_file)
        .then(res => {
            // Skip blank lines (e.g. the trailing newline): an empty record gives
            // an empty input vector and makes the matrix product fail.
            const records = res.toString().split('\n').filter(line => line.trim() !== '');
            // Epochs
            for (let i = 0; i < _init.epochs; i++) {
                // Records
                records.forEach(record => {
                    const fields = record.split(',');
                    const input = normalize_arr(fields);
                    // Train the network on this example
                    nn.train(input, getTarget(_init.node_num[2], fields[0]));
                });
                console.log(`第${i + 1}世代`);
            }
            if (_init.callback) {
                _init.callback(nn);
            }
            return nn;
        })
        .catch(err => console.error(err));
}

/**
 * Score a trained network on a CSV file of labelled records and log the
 * fraction that it classifies correctly.
 * @param {NeuralNetwork} n_class trained network (anything with a query() method)
 * @param {string} test_file CSV file, label first on each line
 * @returns {Promise<number|undefined>} resolves to the score (NaN when no record
 *   could be scored), or to undefined when the file cannot be read
 */
function test(n_class, test_file) {
    return fs.readFile(test_file).then(
        res => {
            // Ignore blank lines such as the trailing newline.
            const records = res.toString().split('\n').filter(line => line.trim() !== '');
            // Scorecard: 1 for a correct prediction, 0 otherwise
            const scorecard = [];
            records.forEach(record => {
                const fields = record.split(',');
                const input = normalize_arr(fields);
                const target = Number(fields[0]);

                try {
                    const result = n_class.query(input).flat();
                    // The predicted class is the index of the strongest output.
                    const actual = result.indexOf(math.max(result));
                    scorecard.push(actual === target ? 1 : 0);
                } catch (err) {
                    console.log(err);
                }
            });

            const score = scorecard.reduce((sum, point) => sum + point, 0) / scorecard.length;
            console.log('得分: ', score);
            return score;
        },
        err => console.error('数据读取失败！', err)
    );
}

// Full MNIST run: train on the pjreddie training set, then score the network
// on the matching test set from the training callback.
const network_info = {
    node_num: [784, 200, 10],
    learning_grate: 0.1,
    train_file: path.join(__dirname, './mnist_dataset/pjreddie/mnist_train.csv'),
    // train_file: path.join(__dirname, './mnist_dataset/mini/mnist_train_100.csv'),
    epochs: 1, // train for a single epoch
    // rotate: true,
    // rotate_info: [10, -10],
    callback_each_epoch: () => {},
    callback: nn_class => test(nn_class, path.join(__dirname, './mnist_dataset/pjreddie/mnist_test.csv'))
};

// train() is asynchronous: stop the timer when its promise settles, not right
// after the call returns. The test run started from the callback is not
// awaited by train(), so it may still be running when the timer stops.
console.time();
train(network_info).finally(() => console.timeEnd());